package top.cyuw.simplerpc.extension;

import lombok.extern.slf4j.Slf4j;
import top.cyuw.simplerpc.annotation.SPI;
import top.cyuw.simplerpc.util.StringUtils;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * @author chen
 * @date 2023/3/13 10:40
 */
@Slf4j
public class ExtensionLoader<T> {

    private static final String EXTENSIONS_DIRECTORY = "META-INF/extensions/";
    private static final Map<Class<?>, ExtensionLoader<?>> EXTENSION_LOADERS = new ConcurrentHashMap<>();
    private static final Map<Class<?>, Object> EXTENSION_INSTANCES = new ConcurrentHashMap<>();

    private final Class<?> type;
    private final Map<String, Object> instances = new ConcurrentHashMap<>();
    private volatile Map<String, Class<?>> cachedClasses;

    private ExtensionLoader(Class<?> type) {
        this.type = type;
    }

    public static <E> ExtensionLoader<E> of(Class<E> type) {
        if (type == null) {
            throw new IllegalArgumentException("extension type should not be null.");
        }
        if (!type.isInterface()) {
            throw new IllegalArgumentException("extension type must be an interface.");
        }
        if (!type.isAnnotationPresent(SPI.class)) {
            throw new IllegalArgumentException("extension type must be annotated by @SPI");
        }
        return (ExtensionLoader<E>) EXTENSION_LOADERS.computeIfAbsent(type, k -> new ExtensionLoader<E>(type));
    }

    public T getExtension(String name) {
        if (StringUtils.isEmpty(name)) {
            throw new IllegalArgumentException("extension type should not be null.");
        }
        return (T) instances.computeIfAbsent(name, k -> createExtension(name));
    }

    private Object createExtension(String name) {
        Class<?> clazz = getExtensionClasses().get(name);
        if (clazz == null) {
            throw new RuntimeException("No such extension of name " + name);
        }
        return EXTENSION_INSTANCES.computeIfAbsent(clazz, k -> {
            try {
                return clazz.newInstance();
            } catch (Exception e) {
                log.error("create extension failed: " + e.getMessage(), e);
            }
            return null;
        });
    }

    private Map<String, Class<?>> getExtensionClasses() {
        // double check
        if (cachedClasses == null) {
            synchronized (type) {
                if (cachedClasses == null) {
                    cachedClasses = loadDirectory();
                }
            }
        }
        return cachedClasses;
    }

    private Map<String, Class<?>> loadDirectory() {
        Map<String, Class<?>> classes = new HashMap<>();

        String fileName = ExtensionLoader.EXTENSIONS_DIRECTORY + type.getName();
        try {
            Enumeration<URL> urls;
            ClassLoader classLoader = ExtensionLoader.class.getClassLoader();
            urls = classLoader.getResources(fileName);
            if (urls != null) {
                while (urls.hasMoreElements()) {
                    URL resourceUrl = urls.nextElement();
                    loadResource(classes, classLoader, resourceUrl);
                }
            }
        } catch (IOException e) {
            log.error("load extensions failed: " + e.getMessage(), e);
        }

        return classes;
    }

    private void loadResource(Map<String, Class<?>> classes, ClassLoader classLoader, URL resourceUrl) {
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(resourceUrl.openStream(), StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                line = line.trim();
                if (StringUtils.isNotEmpty(line)) {
                    int splitter = line.indexOf('=');
                    if (splitter > 0) {
                        String name = line.substring(0, splitter).trim();
                        String className = line.substring(splitter + 1).trim();
                        if (StringUtils.isNotEmpty(name) && StringUtils.isNotEmpty(className)) {
                            try {
                                Class<?> clazz = classLoader.loadClass(className);
                                classes.put(name, clazz);
                            } catch (ClassNotFoundException e) {
                                log.debug("extension class not found: " + e.getMessage(), e);
                            }
                        }
                    }
                }

            }
        } catch (IOException e) {
            log.error("load extension failed: " + e.getMessage(), e);
        }
    }

}
